#include "ShadowShader.h"

//Constructs the shadow-receiving shader. The device and window handle go to BowerShader;
//initShader() then loads the precompiled shadow_vs/shadow_ps shader objects and creates
//this shader's constant buffers and samplers.
ShadowShader::ShadowShader(ID3D11Device* device, HWND hwnd)
	: BowerShader(device, hwnd)
{
	const wchar_t* const vertexShaderFile = L"shadow_vs.cso";
	const wchar_t* const pixelShaderFile = L"shadow_ps.cso";
	initShader(vertexShaderFile, pixelShaderFile);
}

//Releases a COM interface if one is held and clears the pointer, so any later
//release of the same member (e.g. by a base-class destructor) becomes a no-op.
template <typename T>
static void releaseAndClear(T*& resource)
{
	if (resource)
	{
		resource->Release();
		resource = nullptr;
	}
}

//Releases the D3D objects created by initShader() (and the input layout).
ShadowShader::~ShadowShader()
{
	//Release the constant buffers
	releaseAndClear(matrixBuffer);
	releaseAndClear(lightBuffer);

	//Release the samplers
	releaseAndClear(sampleState);
	releaseAndClear(sampleStateShadow);

	releaseAndClear(layout);

	//No explicit base-destructor call here: BowerShader's and BaseShader's destructors
	//run automatically after this body, and invoking one manually would destroy the
	//base subobject twice (undefined behaviour).
}


void ShadowShader::initShader(const wchar_t* vsFilename, const wchar_t* psFilename)
{
	loadVertexShader(vsFilename);
	loadPixelShader(psFilename);

	//Setup the matrix buffer
	D3D11_BUFFER_DESC matrixBufferDesc;
	matrixBufferDesc.Usage = D3D11_USAGE_DYNAMIC;
	matrixBufferDesc.ByteWidth = sizeof(ShadowMatrixBufferType);
	matrixBufferDesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
	matrixBufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
	matrixBufferDesc.MiscFlags = 0;
	matrixBufferDesc.StructureByteStride = 0;
	renderer->CreateBuffer(&matrixBufferDesc, NULL, &matrixBuffer);

	//Setup the light buffer
	D3D11_BUFFER_DESC lightBufferDesc;
	lightBufferDesc.Usage = D3D11_USAGE_DYNAMIC;
	lightBufferDesc.ByteWidth = sizeof(LightBufferType);
	lightBufferDesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
	lightBufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
	lightBufferDesc.MiscFlags = 0;
	lightBufferDesc.StructureByteStride = 0;
	renderer->CreateBuffer(&lightBufferDesc, NULL, &lightBuffer);

	//Setup the base sampler
	D3D11_SAMPLER_DESC samplerDesc;
	samplerDesc.Filter = D3D11_FILTER_MIN_MAG_MIP_LINEAR;
	samplerDesc.AddressU = D3D11_TEXTURE_ADDRESS_WRAP;
	samplerDesc.AddressV = D3D11_TEXTURE_ADDRESS_WRAP;
	samplerDesc.AddressW = D3D11_TEXTURE_ADDRESS_WRAP;
	samplerDesc.MipLODBias = 0.0f;
	samplerDesc.MaxAnisotropy = 1;
	samplerDesc.ComparisonFunc = D3D11_COMPARISON_ALWAYS;
	samplerDesc.BorderColor[0] = 0;
	samplerDesc.BorderColor[1] = 0;
	samplerDesc.BorderColor[2] = 0;
	samplerDesc.BorderColor[3] = 0;
	samplerDesc.MinLOD = 0;
	samplerDesc.MaxLOD = D3D11_FLOAT32_MAX;
	renderer->CreateSamplerState(&samplerDesc, &sampleState);

	//Setup the shadow sampler
	samplerDesc.Filter = D3D11_FILTER_MIN_MAG_MIP_POINT;
	samplerDesc.AddressU = D3D11_TEXTURE_ADDRESS_BORDER;
	samplerDesc.AddressV = D3D11_TEXTURE_ADDRESS_BORDER;
	samplerDesc.AddressW = D3D11_TEXTURE_ADDRESS_BORDER;
	samplerDesc.BorderColor[0] = 1.0f;
	samplerDesc.BorderColor[1] = 1.0f;
	samplerDesc.BorderColor[2] = 1.0f;
	samplerDesc.BorderColor[3] = 1.0f;
	renderer->CreateSamplerState(&samplerDesc, &sampleStateShadow);
}


//Uploads the per-draw constants for the shadow-receiving shader and binds its resources.
//  worldMatrix/viewMatrix/projectionMatrix - camera transforms, written transposed to the matrix buffer.
//  texture        - bound to pixel-shader slot t0.
//  sunShadowMap   - shadow map for the sun light, bound to t1.
//  spotShadowMap  - shadow map for the spot light, bound to t2.
//  hillShadowMaps - six shadow maps, one per hill-light face direction, bound to t3..t8.
//  constant/linear/quadratic - copied into the light buffer's constant/linear/quadratic factors.
//Side effects: regenerates view matrices on _SpotLight/_HillLight where needed. When the
//matrix buffer is updated, it also leaves _HillLight's direction set to the last face
//direction (0, 0, -1).
//If mapping a constant buffer fails, that buffer keeps its previous contents; it is still bound.
void ShadowShader::setShaderParameters(ID3D11DeviceContext* deviceContext, const XMMATRIX &worldMatrix, const XMMATRIX &viewMatrix, const XMMATRIX &projectionMatrix, ID3D11ShaderResourceView* texture, ID3D11ShaderResourceView* sunShadowMap, ID3D11ShaderResourceView* spotShadowMap, ID3D11ShaderResourceView* hillShadowMaps[6], float constant, float linear, float quadratic)
{
	//Number of hill-light faces: view/projection pairs in the matrix buffer and shadow maps bound.
	const int hillFaceCount = 6;

	//Builds a light's view matrix. A light pointing straight along the Y axis is routed
	//through GetYAxisViewMatrix. This is presumably because generateViewMatrix()'s look-at
	//degenerates for a direction parallel to its up vector (TODO confirm).
	auto buildLightView = [this](auto* light) -> XMMATRIX
	{
		const auto direction = light->getDirection();
		if (direction.x == 0 && direction.z == 0)
		{
			XMMATRIX yAxisView = GetYAxisViewMatrix(light);
			//NOTE(review): negating every element also negates w, so a perspective divide in
			//the shader would cancel this flip out. Confirm it has the intended effect.
			if (direction.y < 0.0f)
			{
				yAxisView = -yAxisView;
			}
			return yAxisView;
		}

		light->generateViewMatrix();
		return light->getViewMatrix();
	};

	D3D11_MAPPED_SUBRESOURCE mappedResource;

	//Light constants for the pixel shader. They are written before the hill-face loop
	//below, because that loop changes _HillLight's direction.
	if (SUCCEEDED(deviceContext->Map(lightBuffer, 0, D3D11_MAP_WRITE_DISCARD, 0, &mappedResource)))
	{
		LightBufferType* lightPtr = static_cast<LightBufferType*>(mappedResource.pData);
		lightPtr->sunAmbient = _SunLight->getAmbientColour();
		lightPtr->sunDiffuse = _SunLight->getDiffuseColour();
		lightPtr->sunDirection = _SunLight->getDirection();
		lightPtr->padding = 0.0f;
		lightPtr->spotDiffuse = _SpotLight->getDiffuseColour();
		lightPtr->spotPosition = _SpotLight->getPosition();
		lightPtr->constantFactor = constant;
		lightPtr->linearFactor = linear;
		lightPtr->quadraticFactor = quadratic;
		lightPtr->paddingTwo = XMFLOAT2(0.0f, 0.0f);
		lightPtr->spotDirection = _SpotLight->getDirection();
		lightPtr->paddingThree = 0.0f;
		lightPtr->hillDiffuse = _HillLight->getDiffuseColour();
		lightPtr->hillPosition = _HillLight->getPosition();
		lightPtr->paddingFour = 0.0f;
		deviceContext->Unmap(lightBuffer, 0);
	}
	deviceContext->PSSetConstantBuffers(0, 1, &lightBuffer);

	//Camera and light-space matrices for the vertex shader, all transposed before upload.
	if (SUCCEEDED(deviceContext->Map(matrixBuffer, 0, D3D11_MAP_WRITE_DISCARD, 0, &mappedResource)))
	{
		ShadowMatrixBufferType* matrixPtr = static_cast<ShadowMatrixBufferType*>(mappedResource.pData);
		matrixPtr->world = XMMatrixTranspose(worldMatrix);
		matrixPtr->view = XMMatrixTranspose(viewMatrix);
		matrixPtr->projection = XMMatrixTranspose(projectionMatrix);
		matrixPtr->sunLightView = XMMatrixTranspose(_SunLight->getViewMatrix());
		matrixPtr->sunLightProjection = XMMatrixTranspose(_SunLight->getOrthoMatrix());
		matrixPtr->spotLightView = XMMatrixTranspose(buildLightView(_SpotLight));
		matrixPtr->spotLightProjection = XMMatrixTranspose(_SpotLight->getProjectionMatrix());

		//One view per axis direction of the hill light: +Y, -Y, +X, -X, +Z, -Z.
		const XMFLOAT3 faceDirections[hillFaceCount] =
		{
			XMFLOAT3(0.0f, 1.0f, 0.0f),
			XMFLOAT3(0.0f, -1.0f, 0.0f),
			XMFLOAT3(1.0f, 0.0f, 0.0f),
			XMFLOAT3(-1.0f, 0.0f, 0.0f),
			XMFLOAT3(0.0f, 0.0f, 1.0f),
			XMFLOAT3(0.0f, 0.0f, -1.0f)
		};

		for (int face = 0; face < hillFaceCount; ++face)
		{
			const XMFLOAT3& faceDirection = faceDirections[face];
			_HillLight->setDirection(faceDirection.x, faceDirection.y, faceDirection.z);
			matrixPtr->hillLightViews[face] = XMMatrixTranspose(buildLightView(_HillLight));
			matrixPtr->hillLightProjections[face] = XMMatrixTranspose(_HillLight->getProjectionMatrix());
		}

		deviceContext->Unmap(matrixBuffer, 0);
	}
	deviceContext->VSSetConstantBuffers(0, 1, &matrixBuffer);

	//Set the pixel shader textures: t0 = texture, t1 = sun shadow map, t2 = spot shadow map.
	ID3D11ShaderResourceView* sceneResources[3] = { texture, sunShadowMap, spotShadowMap };
	deviceContext->PSSetShaderResources(0, 3, sceneResources);

	//The matrix buffer carries six hill-face view/projection pairs, so bind all six hill
	//shadow maps (t3..t8); previously only the first one was bound.
	//NOTE(review): confirm shadow_ps declares the hill maps as six consecutive slots from t3.
	deviceContext->PSSetShaderResources(3, hillFaceCount, hillShadowMaps);

	//Set the samplers: s0 = base sampler, s1 = shadow sampler.
	deviceContext->PSSetSamplers(0, 1, &sampleState);
	deviceContext->PSSetSamplers(1, 1, &sampleStateShadow);
}

